package ipratelimit

import (
	"net"
	"runtime/debug"
	"time"

	"gitee.com/clearluo/gotools/util/ipratelimit/internal/typ"

	"gitee.com/clearluo/gotools/util/ipratelimit/internal/blackhouse"
	"gitee.com/clearluo/gotools/util/ipratelimit/internal/ipratelimit"
	"gitee.com/clearluo/gotools/util/ipratelimit/internal/waitblack"
	"gitee.com/clearluo/gotools/util/ipratelimit/internal/whiteip"

	"gitee.com/clearluo/gotools/zaplog"

	"gitee.com/clearluo/gotools/util"
)

// ConfigT holds the tunable parameters of the package-level limiter; pass it
// to Init. Numeric and duration fields that are zero or negative are ignored by
// Init, which keeps the current value for them.
type ConfigT struct {
	WhiteIPStrSli        []string      // IP whitelist; each entry is parsed with typ.IPStr and added to the whitelist by Init
	RateSpeed            int           // default rate at which the token bucket produces tokens
	BucketLen            int           // token bucket capacity
	OnBlackDuration      time.Duration // how long an IP stays in the black house once it is put there
	WaitConnReadDuration time.Duration // how long a connection may wait before it is treated as an empty/idle connection
	BlackTimes           int           // used with IntervalSecond: this many occurrences within IntervalSecond seconds triggers the black house
	IntervalSecond       int           // time window in seconds, used together with BlackTimes

}

// Package-level singletons used by the exported Allow*/AddPreBlack functions.
// They are created in init and reconfigured by Init.
var (
	ipRateLimitObj *ipratelimit.IPRateLimiterT // rate limiter (IPAllow) and black-house occurrence counter (BlackAllow)
	blackHouseObj  *blackhouse.BlackHouseT     // IPs currently in the black house (rejected outright)
	whiteObj       *whiteip.WhiteIPT           // whitelisted IPs; they skip the black-house/rate checks (AllowByIPSimple does not consult it)
	config         *ConfigT                    // active configuration; defaults are set in init, overridden by Init
)

// init installs the default configuration and builds the limiter, black-house
// and whitelist singletons from it. It whitelists the loopback address
// 127.0.0.1 and starts the periodic status logger.
func init() {
	config = &ConfigT{
		RateSpeed:            5,
		BucketLen:            5,
		BlackTimes:           5,
		IntervalSecond:       60,
		OnBlackDuration:      time.Minute,
		WaitConnReadDuration: 5 * time.Second,
	}

	ipRateLimitObj = ipratelimit.NewIPRateLimiter(
		config.RateSpeed,
		config.BucketLen,
		config.BlackTimes,
		config.IntervalSecond,
	)
	blackHouseObj = blackhouse.NewBlackHouseT()
	whiteObj = whiteip.NewWhiteIPT()

	// The literal is a fixed, well-formed IPv4 address, so the parse error is
	// deliberately ignored (presumably it cannot fail).
	loopback, _ := typ.IPStr("127.0.0.1").ToIPInt()
	whiteObj.Add(loopback)

	go printDetail()
}

// printDetail logs a snapshot of the limiter state once a minute. The snapshot
// contains the number of entries reported by ipRateLimitObj.Len, the whitelist
// and the black-house list. It is started from init and loops for the life of
// the process.
//
// The recover is scoped to a single tick rather than the whole goroutine: a
// panic while building one snapshot is logged with its stack, and logging
// resumes on the next tick instead of stopping for good.
func printDetail() {
	ticker := time.NewTicker(time.Minute)
	// Releases the ticker if this loop is ever given an exit path.
	defer ticker.Stop()

	for range ticker.C {
		func() {
			defer func() {
				if err := recover(); err != nil {
					zaplog.Warnf("printDetail panic:%v,stack:%v", err, string(debug.Stack()))
				}
			}()
			zaplog.Infof("[ipratelimit info]|iprateLen:%v|whiteList:%v|blackHouseList:%v",
				ipRateLimitObj.Len(),
				util.AssertMarshal(whiteObj.List()),
				util.AssertMarshal(blackHouseObj.List()),
			)
		}()
	}
}
// Init overrides the package defaults with the fields of cfg. It pushes the
// merged rate-limit values to the limiter via SetParam and adds every address in
// cfg.WhiteIPStrSli to the whitelist.
//
// Numeric and duration fields that are zero or negative keep their current
// value. Whitelist entries that fail to parse are logged and skipped. Init only
// adds to the whitelist; it never removes existing entries. A nil cfg is a
// no-op.
//
// NOTE(review): config is written here without synchronization, while timers
// armed by newWaitT read config from their own goroutines. Presumably Init is
// meant to be called once at startup, before any Allow* call — confirm.
func Init(cfg *ConfigT) {
	if cfg == nil {
		return
	}
	if cfg.RateSpeed > 0 {
		config.RateSpeed = cfg.RateSpeed
	}
	if cfg.BucketLen > 0 {
		config.BucketLen = cfg.BucketLen
	}
	if cfg.OnBlackDuration > 0 {
		config.OnBlackDuration = cfg.OnBlackDuration
	}
	if cfg.WaitConnReadDuration > 0 {
		config.WaitConnReadDuration = cfg.WaitConnReadDuration
	}
	if cfg.BlackTimes > 0 {
		config.BlackTimes = cfg.BlackTimes
	}
	if cfg.IntervalSecond > 0 {
		config.IntervalSecond = cfg.IntervalSecond
	}
	ipRateLimitObj.SetParam(config.RateSpeed, config.BucketLen, config.BlackTimes, config.IntervalSecond)

	for _, v := range cfg.WhiteIPStrSli {
		ipInt, err := typ.IPStr(v).ToIPInt()
		if err != nil {
			zaplog.Warnf("typ[%v] parse err:%v", v, err)
			continue
		}
		whiteObj.Add(ipInt)
	}
}

// AllowByConn reports whether the remote peer of conn may proceed.
//
// If the second result is false, the IP is being limited and the caller must
// not continue; the first result is then nil. This happens in three cases:
//   - conn or its remote address is nil or cannot be parsed into an IP;
//   - the IP is in the black house;
//   - the IP exceeds its rate limit.
//
// If the second result is true, the caller may continue with its business
// logic. It must call Stop on the returned *waitblack.WaitT within
// config.WaitConnReadDuration. Otherwise a timer fires that closes conn and
// puts the IP under black-house monitoring. Whitelisted IPs skip the
// black-house and rate checks and get a WaitT without a timer.
func AllowByConn(conn net.Conn) (*waitblack.WaitT, bool) {
	if conn == nil {
		zaplog.Debugf("conn parse err:%v", "nil conn")
		return nil, false
	}
	// net.Conn implementations may return a nil Addr; calling String on it
	// would panic.
	remote := conn.RemoteAddr()
	if remote == nil {
		zaplog.Debugf("conn parse err:%v", "nil remote addr")
		return nil, false
	}
	ipTmp, _, err := net.SplitHostPort(remote.String())
	if err != nil {
		zaplog.Debugf("conn parse err:%v", err)
		return nil, false
	}
	ipInt, err := typ.IPStr(ipTmp).ToIPInt()
	if err != nil {
		zaplog.Debugf("typ[%v] parse err:%v", ipTmp, err)
		return nil, false
	}
	if whiteObj.IsWhite(ipInt) {
		return newWaitT(conn, ipInt, true), true
	}

	if blackHouseObj.IsBlack(ipInt) {
		zaplog.Debugf("[%v] in black house", ipInt)
		return nil, false
	}
	if !ipRateLimitObj.IPAllow(ipInt) {
		zaplog.Debugf("[%v] typ rate limit", ipInt)
		return nil, false
	}
	return newWaitT(conn, ipInt, false), true
}

// AllowByIPStr reports whether a request from ipReq may proceed with its
// business logic.
//
// It returns true when the IP is whitelisted, or when it is neither in the
// black house nor over its rate limit. It returns false when the IP is being
// limited. An address that cannot be parsed is not limited: true is returned.
func AllowByIPStr(ipReq string) bool {
	ipInt, err := typ.IPStr(ipReq).ToIPInt()
	switch {
	case err != nil:
		return true
	case whiteObj.IsWhite(ipInt):
		return true
	case blackHouseObj.IsBlack(ipInt):
		zaplog.Debugf("[%v] in black house", ipInt)
		return false
	case !ipRateLimitObj.IPAllow(ipInt):
		zaplog.Debugf("[%v] typ rate limit", ipInt)
		return false
	default:
		return true
	}
}

// AllowByIPSimple is the minimal check: it applies only the per-IP rate limit,
// with no whitelist and no black house. It returns true when the caller may
// proceed with its business logic and false when the IP is rate-limited. An
// address that cannot be parsed is not limited: true is returned.
func AllowByIPSimple(ipReq string) bool {
	ipInt, err := typ.IPStr(ipReq).ToIPInt()
	if err != nil {
		return true
	}
	allowed := ipRateLimitObj.IPAllow(ipInt)
	if !allowed {
		zaplog.Debugf("[%v] typ rate limit", ipInt)
	}
	return allowed
}

// AddPreBlack puts ipReq under black-house monitoring. Whenever BlackAllow
// reports false for the IP (presumably once it has occurred BlackTimes times
// within IntervalSecond seconds), the IP is put in the black house for
// config.OnBlackDuration. Unparseable addresses and whitelisted IPs are ignored.
func AddPreBlack(ipReq string) {
	ipInt, err := typ.IPStr(ipReq).ToIPInt()
	if err != nil || whiteObj.IsWhite(ipInt) {
		return
	}
	if ipRateLimitObj.BlackAllow(ipInt) {
		return
	}
	blackHouseObj.AddBlack(ipInt, config.OnBlackDuration)
}

// newWaitT wraps conn and its parsed IP in a *waitblack.WaitT.
//
// Whitelisted peers (isWhite) get no timer. For every other peer, a one-shot
// timer of config.WaitConnReadDuration is stored in TimerAfterFunc. If it fires,
// BlackAllow is consulted for the IP. When BlackAllow returns false, the IP is
// put in the black house for config.OnBlackDuration (read at fire time). conn
// is then closed if non-nil.
func newWaitT(conn net.Conn, ipInt typ.IPInt, isWhite bool) *waitblack.WaitT {
	wait := &waitblack.WaitT{IPInt: ipInt, Conn: conn}
	if isWhite {
		return wait
	}

	onTimeout := func() {
		zaplog.Debugf("conn timeout:%v", ipInt.String())
		if !ipRateLimitObj.BlackAllow(ipInt) {
			blackHouseObj.AddBlack(ipInt, config.OnBlackDuration)
		}
		if conn == nil {
			return
		}
		_ = conn.Close()
	}
	wait.TimerAfterFunc = time.AfterFunc(config.WaitConnReadDuration, onTimeout)
	return wait
}
